# clean workspace: remove every object visible in the current environment
rm(list = ls())
# run garbage collection
gc()
# clear graphics device (left disabled)
# dev.off()
# print numbers as plain digits instead of scientific notation
options(scipen = 999)
# set timeout limit (relaxed)
options(timeout = 10000)
# set working directory; every relative path below is resolved against it
setwd(dir = "~/Desktop/Mechanical Pollster/Replication_Material_PS/")

# # # 
# # # load utils - Make sure these are installed before proceeding with the replication 
# # # 

library(data.table)
library(questionr)
library(tools)

# Safe replacement for base::sample().
# base::sample() interprets a single number n >= 1 as "sample from 1:n", which
# is a trap when `x` is a vector that merely happens to have length one. This
# wrapper always samples from the elements of `x` itself.
#   x       : vector to sample from
#   size    : number of draws; when missing, a length-one `x` is returned as is
#             and longer vectors fall through to the base::sample() default
#   replace : sample with replacement?
#   prob    : optional sampling weights (irrelevant when `x` has length one)
# Returns a vector of draws from `x`.
sample <- function(x, size, replace = FALSE, prob = NULL) {
  if (length(x) == 1L) {
    # a single draw from a single element can only be that element
    if (missing(size) || size == 1) {
      return(x)
    }
    # mirror base::sample(): more draws than elements needs replacement
    if (size > 1 && !replace) {
      stop("cannot take a sample larger than the population when 'replace = FALSE'")
    }
    # honour `size` (including size = 0) instead of always returning one value
    return(rep(x, size))
  }
  base::sample(x, size = size, replace = replace, prob = prob)
}

# map function which works on both vectors and sims matrices
library(bayestestR)
# MAP point estimate of posterior draws.
#   x : a plain vector of draws, or an object with dimensions (e.g. a sims
#       matrix), in which case one estimate per column is returned
map <- function(x) {
  map_point <- function(draws) {
    map_estimate(draws)$MAP_Estimate
  }
  # no dimensions: treat the input as a single set of draws
  if (is.null(dim(x))) {
    return(map_point(x))
  }
  apply(x, 2, map_point)
}

# plotting utils
library(vioplot)
# cols
cols.light <- c(D = 'skyblue',G = 'green',K = 'violet',L = 'yellow',R = 'lightcoral',`stay home` = 'darkgrey',W = 'pink',OTHER = 'darkgrey',favourable = 'springgreen',unfavourable = 'orangered',`no opinion` = 'darkgrey',net = 'lavenderblush')
cols.dark <-  c(D = 'dodgerblue',G = 'darkgreen',K = 'purple',L = 'orange',R = 'orangered',`stay home` = 'black',W = 'darkred',OTHER = 'black',favourable = 'springgreen4',unfavourable = 'orangered4',`no opinion` = 'black',net = 'lavenderblush4')

# Draw right-sided half violins, one per column of `x`.
#   x          : vector or matrix-like object, coerced with as.matrix();
#                every column becomes one violin
#   var.levels : labels handed to vioplot() as `names`
#   cols.light : fill colour(s) handed to vioplot() as `col`
#   cols.dark  : accepted for symmetry with `cols.light`; not used in the body
#   reference  : optional value(s); when not NULL a dashed horizontal line is
#                drawn at that height
#   ...        : forwarded unchanged to vioplot()
plot.effect_violin <- function(x, var.levels, cols.light, cols.dark, reference = NULL, ...) {
  draws <- as.matrix(x)

  # vioplot() takes a list of vectors, so split the matrix column by column
  per_column <- lapply(1:ncol(draws), function(col_idx) draws[, col_idx])

  vioplot(per_column, names = var.levels, col = cols.light, side = "right", ...)

  # optional reference line
  if (!is.null(reference)) {
    abline(h = reference, lty = 2)
  }
}

# I'm going to use JAGS to do some basic modeling
library(R2jags)

# load crosstabs from polls (the code below uses the object `cross`)
load(file = 'crosstabs_polls.RData')

# last minute cleaning of cross data: polls without a sponsor are labelled 'No'
cross$sponsor <- ifelse(is.na(cross$sponsor), 'No', cross$sponsor)
# columns which jointly identify one poll x condition row
idvars <- c('pollster', 'sponsor', 'start.date', 'end.date', 'population', 'variable', 'condition')

# identify dependent variable
dep.var <- 'vote2024'

# we don't care for plotting third parties or the alternative PoSSUM estimates
THIRD_PARTIES <- FALSE
INCLUDE_ALL_PoSSUM <- FALSE

# plots are written to this folder; create it up front so pdf() cannot fail
# on a missing directory
dir.create('generated_plots', showWarnings = FALSE)

# gather the date ranges for which prediction files exist.
# The pattern is anchored on both ends and the dot is escaped, so only files
# named exactly 'crosstabs_2024<...>.RData' are picked up (and can later be
# loaded back under that same name).
pred.files <- dir()
pred.files <- pred.files[grepl('^crosstabs_2024.*\\.RData$', pred.files)]
dates <- sub('^crosstabs_', '', sub('\\.RData$', '', pred.files))

# loop over the desired conditions 
for(crosstab.condition in c('population','state')){
  # loop over relevant date ranges 
  for(date in dates){
    # # # 
    # # # 
    # # # LOAD AND CLEAN OUR PREDICTIONS 
    # # # 
    # # # 
    
    # load predicted values for this date range; the code below relies on the
    # object `crosstab.pred_list` being made available by this file
    load(file = paste0('crosstabs_', date, '.RData'))
    # work on a copy, to which the special districts are appended for the
    # `state` results
    crosstab.pred_list_spec.High <- crosstab.pred_list
    if (crosstab.condition == 'state') {
      # rows of the simple state table for Nebraska and Maine
      is.special <-
        crosstab.pred_list_spec.High$state_simple$state_simple %in% c('Nebraska', 'Maine')
      tmp <- crosstab.pred_list_spec.High$state_simple[is.special]
      # rename the key column so the rows line up with the electoral-college table
      names(tmp)[names(tmp) %in% 'state_simple'] <- 'state_electoral.college'
      crosstab.pred_list_spec.High$state_electoral.college <-
        rbindlist(list(crosstab.pred_list_spec.High$state_electoral.college, tmp))
    }
    
    # # # 
    # # # 
    # # # LOAD AND CLEAN THE COMPARISON CROSSTABS FROM OTHER POLLSTERS 
    # # # 
    # # # 
    
    # crosstabs of the dependent variable for the current condition, reshaped
    # to wide format: one row per poll x condition, one column per response
    crosstab.marginal <-
      reshape(
        cross[variable == dep.var & grepl(crosstab.condition, condition)],
        idvar = idvars,
        timevar = 'response',
        direction = 'wide'
      )

    # align time frame with PoSSUM (for this particular poll): `date` has the
    # form '<start>_<end>', and a poll is kept when its fieldwork overlaps
    # that window
    possum.window <- as.Date(strsplit(date, split = '_')[[1]], '%Y-%m-%d')
    keep <-
      which(
        crosstab.marginal$start.date <= possum.window[2] &
          crosstab.marginal$end.date >= possum.window[1]
      )
    crosstab.marginal <- crosstab.marginal[keep]

    # Delaware wasn't yet available at the time of generating the plots for the
    # initial PS submission, so drop it for perfect replication
    if (any(crosstab.marginal$condition == 'state_delaware')) {
      crosstab.marginal <- crosstab.marginal[condition != 'state_delaware']
    }

    # remove stay home - make this a LV posterior
    if (any(grepl('stay home', names(crosstab.marginal)))) {
      crosstab.marginal <-
        crosstab.marginal[, !grepl('stay home', names(crosstab.marginal)), with = FALSE]
    }

    # aggregate G W L into `OTHER`: row-wise total over the matching columns ...
    is.other.col <- grepl('G|W|L|OTHER', names(crosstab.marginal))
    crosstab.marginal$N.OTHER <-
      apply(
        crosstab.marginal[, is.other.col, with = FALSE],
        1,
        function(counts) {
          sum(counts, na.rm = TRUE)
        }
      )
    # ... then drop the individual G / W / L columns ...
    crosstab.marginal <-
      crosstab.marginal[, !grepl('G|W|L', names(crosstab.marginal)), with = FALSE]
    # ... and treat a total of zero as "not reported"
    crosstab.marginal$N.OTHER <-
      ifelse(crosstab.marginal$N.OTHER == 0, NA, crosstab.marginal$N.OTHER)
    # coerce every non-id column to numeric
    crosstab.marginal <-
      crosstab.marginal[
        ,
        lapply(.SD, function(x) {
          as.numeric(ifelse(is.na(x), NA, x))
        }),
        by = c(idvars)
      ]
  
  
  # # # 
  # # # 
  # # # MAKE DATA INTO A FORMAT AMENABLE TO ANALYSIS BY JAGS
  # # # 
  # # # 
  
  # Create a set identifier: label every poll with the response options it
  # actually reports (e.g. 'D-R-OTHER'). Missing options are detected with
  # is.na(), so a missing first column cannot leak into the label as 'NA-...'.
  y.matrix <- crosstab.marginal[, grepl('N.', names(crosstab.marginal)), with = FALSE]
  y.labs <- gsub('N.', '', names(y.matrix))

  crosstab.marginal$set <-
    apply(y.matrix, 1, function(x) paste0(y.labs[!is.na(x)], collapse = '-'))
  crosstab.marginal$set <- as.factor(crosstab.marginal$set)
  crosstab.marginal$set_id <- as.integer(crosstab.marginal$set)
  # integer ids (factor codes) for pollster and condition
  crosstab.marginal$pollster <- as.factor(crosstab.marginal$pollster)
  crosstab.marginal$pollster_id <- as.integer(crosstab.marginal$pollster)
  crosstab.marginal$condition <- as.factor(crosstab.marginal$condition)
  crosstab.marginal$condition_id <- as.integer(crosstab.marginal$condition)

  # create id to connect to predicted values: name of the element of the
  # prediction list that matches the current condition
  crosstab.condition_id <-
    if (crosstab.condition == 'population') {
      'date'
    } else if (grepl('state', crosstab.condition)) {
      'state_electoral.college'
    } else {
      crosstab.condition
    }

  # for states, keep only predictions for states that were actually polled
  # (the first column of the prediction table holds the state name)
  if (grepl('state', crosstab.condition)) {
    polled.states <- gsub('state_', '', levels(crosstab.marginal$condition))
    is.polled <-
      tolower(crosstab.pred_list_spec.High[[crosstab.condition_id]][[1]]) %in% polled.states
    crosstab.pred_list_spec.High[[crosstab.condition_id]] <-
      crosstab.pred_list_spec.High[[crosstab.condition_id]][is.polled]
  }


  # aggregate `OTHER` in predictions. The available party columns are looked
  # up in the table that is summed below, not in the `date` table.
  parties <- c('G', 'L', 'W')
  parties <- parties[parties %in% names(crosstab.pred_list_spec.High[[crosstab.condition_id]])]

  crosstab.pred_list_spec.High[[crosstab.condition_id]]$OTHER <-
    rowSums(crosstab.pred_list_spec.High[[crosstab.condition_id]][, ..parties])
  
  
  # prepare jags data
  # y : polls x choices matrix of counts (NA where a choice was not reported)
  # n : reported total per poll
  # N : number of polls
  y <- as.matrix(crosstab.marginal[, grepl('N.', names(crosstab.marginal)), with = FALSE])
  n <- rowSums(y, na.rm = TRUE)
  N <- dim(y)[1]

  # no comparison poll overlaps this date range: there is nothing to model or
  # plot, so move on to the next date instead of failing on y[1, ]
  if (N == 0) {
    message('No comparison polls for ', crosstab.condition, ' / ', date, ' - skipping')
    next
  }

  # get separate data to estimate the independent models in the same code:
  # for every poll i, the counts of the choices it actually reports ...
  poll.index <- seq_len(N)
  y_set_prime_j <- lapply(poll.index, function(i) y[i, ][!is.na(y[i, ])])
  names(y_set_prime_j) <- paste0('y_set_prime_j', poll.index)

  # ... the labels of those choices ...
  J_prime_labs <- lapply(y_set_prime_j, function(x) gsub('N.', '', names(x)))
  names(J_prime_labs) <- paste0('J_prime_labs', poll.index)

  # ... how many choices there are ...
  J_prime <- lapply(y_set_prime_j, length)
  names(J_prime) <- paste0('J_prime', poll.index)

  # ... and the total count over them
  m <- lapply(y_set_prime_j, sum)
  names(m) <- paste0('m', poll.index)

  # prepare unchangeable
  jags.data <-
    list(
      N = dim(crosstab.marginal)[1],
      n = n,
      J = dim(y)[2],
      J_labs = gsub('N.', '', colnames(y)),

      pollster_id = crosstab.marginal$pollster_id,
      pollster_N = max(crosstab.marginal$pollster_id),
      pollster_labs = levels(crosstab.marginal$pollster),

      condition_id = crosstab.marginal$condition_id,
      condition_N = max(crosstab.marginal$condition_id),
      condition_labs = levels(crosstab.marginal$condition)
    )
  # append stuff for independent models
  jags.data <- c(jags.data, y_set_prime_j, J_prime_labs, J_prime, m)

  # drop undefined (zero-length) entries
  jags.data <- jags.data[lengths(jags.data) > 0]
  
  # now prepare the JAGS code dynamically.
  # Model the vote share estimates independently: every poll i gets its own
  # multinomial likelihood with a Dirichlet prior whose concentration is 0.5
  # for each of the J_prime<i> choices reported by that poll.
  poll.blocks <-
    vapply(
      seq_len(N),
      function(i) {
        paste0(
          # Likelihood
          "y_set_prime_j", i, " ~ dmulti(lambda", i, ", m", i, ");\n",
          "lambda", i, " ~ ddirch(nu", i, ");\n\n",
          "for(j in 1:J_prime", i, "){\n",
          "  nu", i, "[j] <- 0.5 ;\n",
          "}\n"
        )
      },
      character(1)
    )
  # wrap the per-poll pieces in the model block
  jags_model_code <- paste0("model{\n\n", paste0(poll.blocks, collapse = ""), "}")

  # write the model to a temporary file for JAGS
  tmpf <- tempfile()
  cat(jags_model_code, file = tmpf)

  # echo the model to the console
  cat(jags_model_code)

  # monitor the share vector of every poll. The concentration nodes are
  # constants named nu1, nu2, ...; there is no node called 'nu', so it is not
  # requested.
  params.to.store <- paste0('lambda', seq_len(N))

  # fit the model
  jags.fit <-
    jags.parallel(
      data = jags.data,
      parameters.to.save = params.to.store,
      n.iter = 200000,
      model.file = tmpf,
      n.chains = 8,
      n.cluster = 8,
      n.burnin = 185000,
      n.thin = 8
    )

  # the model file is no longer needed once sampling has finished
  unlink(tmpf)
  
  # convergence check: scatter the Rhat of every row of the fit summary, with
  # a y-axis that always spans at least 0.9 to 1.25
  rhat <- jags.fit$BUGSoutput$summary[, 'Rhat']
  rhat.ylim <- c(min(min(rhat), 0.9), max(max(rhat), 1.25))
  plot(rhat, ylim = rhat.ylim, main = 'Convergence', ylab = expression(hat(R)))
  # solid horizontal reference line at 1.05
  abline(h = 1.05, lty = 1)

  # add PoSSUM as one extra "pollster" label (the last one), used by the
  # plotting code
  jags.data$pollster_labs <- append(jags.data$pollster_labs, 'PoSSUM MrP')
  
  
  # # # state-level plot - swing states
  # One histogram panel per polled state: the distribution of the PoSSUM MrP
  # R-D margin (in points), with the polling average marked for comparison.
  if (grepl('state', crosstab.condition) && date == '2024-08-15_2024-09-12') {
    pdf(file = paste0('generated_plots/states.comprehensive_', date, '.pdf'), width = 10, height = 16.5, bg = "white")
    par(mfrow = c(12, 3), mar = c(3, 3, 3, 1.5))
    n.sims <- jags.fit$BUGSoutput$n.sims
    for (cond.lab in levels(crosstab.marginal$condition)) {
      parties <- c('R', 'D')
      # lambda.list holds, for each party and each pollster, a matrix of
      # simulated shares (one column per poll), named '<party> : <pollster>';
      # pollsters without a poll in this state keep a NULL entry
      lambda.list <- list()
      for (j.id in parties) {
        lambda.tmp <- vector("list", length(jags.data$pollster_labs))
        for (k in seq_along(jags.data$pollster_labs)) {
          if (k == (jags.data$pollster_N + 1)) {
            # last slot is PoSSUM: take its predictions for this state and pad
            # them with NA up to the number of JAGS simulations
            pred.condition_spec.High <- crosstab.pred_list_spec.High[[crosstab.condition_id]][[j.id]][tolower(crosstab.pred_list_spec.High[[crosstab.condition_id]][[1]]) == gsub('state_', '', cond.lab)]
            lambda.tmp[[k]] <- as.matrix(c(pred.condition_spec.High, rep(NA, n.sims - length(pred.condition_spec.High))))
          }
          # for every pollster go through each poll
          for (i in seq_len(N)) {
            # skip polls of other pollsters or of other states
            if (jags.data$pollster_id[i] != k || jags.data$condition_labs[jags.data$condition_id[i]] != cond.lab) next
            # Check if the choice is among those in the choice set for the poll
            poll.labs <- jags.data[[paste0('J_prime_labs', i)]]
            if (j.id %in% poll.labs) {
              # if so, get the simulations and bind them in a matrix to other polls from that pollster
              lambda.tmp[[k]] <- cbind(
                lambda.tmp[[k]],
                jags.fit$BUGSoutput$sims.list[[paste0('lambda', i)]][, which(poll.labs == j.id)]
              )
            }
          }
          # assign a name to the pollster in the list
          names(lambda.tmp)[k] <- paste(j.id, ':', jags.data$pollster_labs[k])
        }
        lambda.list <- append(lambda.list, lambda.tmp)
      }

      # Calculate delta: PoSSUM R-D margin in percentage points (NA padding removed)
      delta <-
        100 * lambda.list$`R : PoSSUM MrP`[!is.na(lambda.list$`R : PoSSUM MrP`)] -
        100 * lambda.list$`D : PoSSUM MrP`[!is.na(lambda.list$`D : PoSSUM MrP`)]

      # Determine the range of delta and expand it slightly
      delta_range <- range(delta, na.rm = TRUE)
      delta_min <- delta_range[1] - 1  # Expand lower bound slightly
      delta_max <- delta_range[2] + 1  # Expand upper bound slightly

      # Define the desired bin width
      bin_width <- 1

      # Create breaks that fully cover the range of delta
      breaks <- seq(delta_min, delta_max, by = bin_width)

      # Calculate histogram without plotting
      delta.hist <- hist(x = delta, plot = FALSE, breaks = breaks)

      # Share of delta draws above zero, i.e. with R ahead of D
      pr_r_greater_d <- mean(delta > 0)

      # Interpolate the bar colour based on pr_r_greater_d:
      # 0 -> dodgerblue, 1 -> orangered
      color_interp <- colorRampPalette(c("dodgerblue", "orangered"))
      hist_color <- color_interp(100)[max(1, round(pr_r_greater_d * 100))]

      # Plot the histogram with the interpolated color
      plot(
        delta.hist,
        main = toTitleCase(gsub('state_', '', cond.lab)),
        xlab = 'R-D',
        col = hist_color, border = NA,
        xlim = c(-35, 35)
      )

      # Add vertical line at polling average
      # take pollsters who fielded polls for this state ...
      polls.temp <- lambda.list[!vapply(lambda.list, is.null, logical(1))]
      # ... and drop possum from this (logical mask: safe when nothing matches)
      polls.temp <- polls.temp[!grepl('PoSSUM', names(polls.temp))]
      # polling average per party. Entries are selected on the '<party> :' name
      # prefix; matching the bare letter anywhere in the name would also pick
      # up the other party's draws for pollsters whose name contains 'R' or 'D'.
      polling.avg <-
        vapply(
          parties,
          function(j) {
            mean(Reduce(cbind, polls.temp[startsWith(names(polls.temp), paste(j, ':'))]))
          },
          numeric(1)
        )
      polling.avg <- polling.avg['R'] - polling.avg['D']

      abline(v = 100 * polling.avg, lty = 2, lwd = 1.75)
      abline(v = median(delta), lty = 1, lwd = 1.75)

      # Add a legend showing R-D value and probability; place it on the side
      # opposite to the mode of delta
      if (map(delta) > 0) {
        legend.place <- 'topleft'
      } else {
        legend.place <- 'topright'
      }
      legend(
        legend.place,
        legend = c(
          paste0('Med.(R-D) = ', round(median(delta), 1)),
          paste0('Polling Avg. = ', round(polling.avg * 100, 1)),
          paste0('Pr(R>D) = ', round(pr_r_greater_d, 1)),
          paste0('Bias = ', round(median(delta) - polling.avg * 100, 1))
        ),
        lty = c(1, 2, NA, NA),
        lwd = c(1, 1, NA, NA),
        cex = 1,
        bty = 'n'
      )

    }
    dev.off()

  }
  
  # # # topline numbers
  # One panel per choice: a violin per poll of the simulated share, pollsters
  # ordered by their median estimate, PoSSUM MrP highlighted, and the observed
  # poll shares overlaid as crosses.
  if (grepl('population', crosstab.condition) && date != '2024-08-15_2024-09-12') {

    height <- 10
    mfrow <- c(2, 2)

    # without third parties only the two main candidates are plotted
    if (THIRD_PARTIES == FALSE) {
      height <- 7.5
      jags.data$J_labs <- c('D', 'R')
      mfrow <- c(1, 2)
    }

    # choice labels in the column order of `y`. Columns of `y` are looked up
    # through these, because jags.data$J_labs may have been overwritten above
    # and then no longer lines up with the columns of `y`.
    y.col.labs <- gsub('N.', '', colnames(y))
    n.sims <- jags.fit$BUGSoutput$n.sims

    pdf(file = paste0('generated_plots/', dep.var, '_', crosstab.condition, '_', date, '.pdf'), width = 16.5, height = height, bg = "white")

    par(mfrow = mfrow, oma = c(5, 0, 3, 0), mar = c(12, 3, 3, 3))

    for (j.id in jags.data$J_labs) {

      # Assign candidate name based on choice name
      main <-
        switch(
          j.id,
          D = 'Kamala Harris',
          R = 'Donald Trump',
          K = 'Robert F. Kennedy Jr.',
          'Other'
        )

      # Initialise list with simple estimates: one matrix of simulated shares
      # per pollster (one column per poll)
      lambda.tmp <- vector("list", length(jags.data$pollster_labs))
      for (k in seq_along(jags.data$pollster_labs)) {

        if (k == (jags.data$pollster_N + 1)) {
          # last slot is PoSSUM: its predictions padded with NA to n.sims rows
          lambda.tmp[[k]] <- as.matrix(c(crosstab.pred_list_spec.High$date[[j.id]], rep(NA, n.sims - max(crosstab.pred_list_spec.High$date$sim_id))))
        }
        # for every pollster go through each poll
        for (i in seq_len(N)) {
          # skip polls of other pollsters
          if (jags.data$pollster_id[i] != k) next
          # Check if the choice is amongst those in the choice set for the poll
          poll.labs <- jags.data[[paste0('J_prime_labs', i)]]
          if (j.id %in% poll.labs || j.id == 'net') {
            # if so, get the simulations and bind them in a matrix to other polls from that pollster
            lambda.tmp[[k]] <- cbind(
              lambda.tmp[[k]],
              jags.fit$BUGSoutput$sims.list[[paste0('lambda', i)]][, which(poll.labs == j.id)]
            )
          }
        }
        # assign a name to the pollster in the list
        names(lambda.tmp)[k] <- jags.data$pollster_labs[k]
      }

      # of the entries whose name mentions PoSSUM keep only 'PoSSUM MrP'
      lambda.tmp <- lambda.tmp[!grepl('PoSSUM', names(lambda.tmp)) | names(lambda.tmp) == "PoSSUM MrP"]

      # take the average per pollster and bind into a matrix - useful for ordering
      pi_hat.tmp <- sapply(lambda.tmp, function(x) {
        if (length(dim(x)[2]) == 0) {
          rep(NA, n.sims)
        } else {
          rowMeans(x)
        }
      })

      # point estimate (median) for each pollster
      pi_hat.point <- apply(pi_hat.tmp, 2, function(x) {
        if (all(is.na(x))) {
          NA
        } else {
          ifelse(all(x == 0 | is.na(x)), 0, median(x, na.rm = TRUE))
        }
      })
      # pollster names in plotting order (left to right)
      plot.order <- names(lambda.tmp)[order(pi_hat.point)]

      # initialise plot
      xlim.max <- dim(pi_hat.tmp)[2]
      ylim.min <- min(c(pi_hat.tmp, unlist(lambda.tmp)), na.rm = TRUE)
      ylim.max <- max(c(pi_hat.tmp, unlist(lambda.tmp)), na.rm = TRUE)
      if (THIRD_PARTIES == FALSE) {
        ylim.min <- 0.35
        ylim.max <- 0.6
      }
      grid.v.every <- 1
      grid.h.every <- 0.01
      # reference line: mean simulated share over all non-PoSSUM polls
      ref <- mean(unlist(lambda.tmp[!grepl('PoSSUM', names(lambda.tmp))]), na.rm = TRUE)

      # empty canvas
      plot(
        y = seq(-1, 1, length.out = xlim.max),
        x = 1:(xlim.max),
        ylim = c(ylim.min, ylim.max),
        xlim = c(0.5, xlim.max + 0.5),
        pch = NA,
        xaxt = 'n',
        xlab = '',
        ylab = '',
        main = main
      )
      # background grid
      abline(
        v = seq(0, xlim.max, by = grid.v.every),
        h = seq(-1, 1, by = grid.h.every),
        col = adjustcolor(col = 'darkgrey', 0.15)
      )

      # plot each poll independently on the grid
      for (k in 1:xlim.max) {
        if (is.null(lambda.tmp[[k]])) {
          next
        }
        # where should we plot it ?
        at <- which(plot.order == names(lambda.tmp)[k])
        # PoSSUM gets the dark colour, every other pollster the light one
        violin.col <-
          if (grepl('PoSSUM', names(lambda.tmp)[k])) {
            adjustcolor(cols.dark[j.id], 0.75)
          } else {
            adjustcolor(cols.light[j.id], 0.5)
          }
        for (h in seq_len(ncol(lambda.tmp[[k]]))) {
          x <- lambda.tmp[[k]][, h]
          if (all(is.na(x))) {
            next
          }
          plot.effect_violin(
            x = x,
            var.levels = '',
            at = at,
            cols.light = violin.col,
            horizontal = FALSE,
            reference = ref,
            las = 3,
            cex.axis = 0.5,
            ylim = c(ylim.min, ylim.max),
            xlim = c(0.5, xlim.max + 0.5),
            add = TRUE
          )
        }
      }

      # pollster labels; PoSSUM MrP is blanked here and redrawn in bold below
      axis(side = 1,
           at = 1:(xlim.max),
           labels = gsub('PoSSUM MrP', '', plot.order),
           las = 3)
      axis(side = 1,
           at = 1:(xlim.max),
           labels = ifelse(plot.order == "PoSSUM MrP", "PoSSUM MrP", ""),
           las = 3,
           font = 2)

      # figure title built from the date range, e.g. '09/13 to 09/20'
      outer.main <- paste0('Fieldwork dates span ', gsub('\\-', '\\/', gsub('2024-', '', paste0(unlist(strsplit(date, '_')), collapse = ' to '))))

      # overlay the observed share of this choice in every poll as a cross
      for (k in names(lambda.tmp)) {
        if (!k %in% c('PoSSUM MrP')) {
          # counts of this pollster's polls; drop = FALSE keeps a matrix for a
          # single poll (and for none), so one code path serves all cases
          y.obs <- y[which(crosstab.marginal$pollster == k), , drop = FALSE]
          # shares within each poll (row), then the column of the current choice
          pi.obs <- (y.obs / rowSums(y.obs, na.rm = TRUE))[, which(y.col.labs == j.id)]
          points(
            x = rep(which(plot.order == k),
                    sum(crosstab.marginal$pollster == k)),
            y = pi.obs,
            pch = 4,
            col = 'black'
          )
        }
      }
    }
    mtext(side = 3, text = outer.main, outer = TRUE, cex = 1.25)

    dev.off()
  }
  
}

}

# close the graphics device left open by the convergence plots, if any;
# calling dev.off() with only the null device open is an error
if (dev.cur() > 1L) {
  dev.off()
}


